import time
import torch
import sys
from util.trainer.AverageMeter import AverageMeter
from util.trainer.accuracy import accuracy
from util.trainer.loss_and_top1_acc import loss_and_top1_acc
from util.landscape.util import seconds2days_hours_minutes_seconds
def _is_main_process(rank):
    """Return True for the process that should print: rank 0, or ``rank is None``."""
    return rank is None or rank == 0


def _sample_true_fisher_labels(output):
    """Sample one class index per row from ``softmax(output, dim=1)``.

    Used by K-FAC / EKFAC to build "true Fisher" statistics from labels drawn
    from the model's own predictive distribution.

    Sampling is done on the CPU, and the labels are moved back to
    ``output.device`` so they sit next to ``output`` whatever CUDA device
    this process uses. ``squeeze(1)`` only drops the sample axis, so a batch
    of size 1 yields shape ``(1,)`` rather than a 0-d tensor.

    Args:
        output: 2-D logits tensor (``torch.multinomial`` requires 1-D/2-D input).

    Returns:
        Long tensor of shape ``(output.size(0),)`` on ``output.device``.
    """
    with torch.no_grad():
        probs = torch.nn.functional.softmax(output.detach().cpu(), dim=1)
        return torch.multinomial(probs, 1).squeeze(1).to(output.device)


def _backward_and_step(args, model, criterion, optimizer, input_var, target_var,
                       output, loss, create_graph):
    """Compute gradients and apply one optimizer update, dispatching on ``args.opt``.

    - ``'KO'``: ``optimizer.step(loss)``. For ``args.k_version < 2.0``,
      ``loss.backward()`` is called first. For ``>= 2.0`` no backward is run
      here, presumably because the optimizer handles it (TODO confirm).
    - ``'lbfgs'``: ``optimizer.step(closure)``, where the closure re-runs the
      forward and backward passes.
    - Anything else: ``loss.backward(create_graph=create_graph)`` followed by
      ``optimizer.step()``. For ``'kfac'``/``'ekfac'``, every
      ``optimizer.TCov`` steps an extra backward pass on sampled labels is run
      first with ``optimizer.acc_stats = True``. Its gradients are then
      discarded.
    """
    optimizer.zero_grad()
    if args.opt == 'KO':
        if args.k_version >= 2.0:
            optimizer.step(loss)
        else:
            loss.backward()
            optimizer.step(loss)
    elif args.opt == 'lbfgs':
        def closure():
            optimizer.zero_grad()
            closure_loss = criterion(model(input_var), target_var)
            closure_loss.backward()
            return closure_loss
        optimizer.step(closure)
    else:
        if args.opt in ('kfac', 'ekfac') and optimizer.steps % optimizer.TCov == 0:
            # Compute the true Fisher: backprop a loss on labels sampled from
            # the model's own predictions while the optimizer records stats.
            optimizer.acc_stats = True
            sampled_y = _sample_true_fisher_labels(output)
            loss_sample = criterion(output, sampled_y)
            # Keep the graph alive: the real loss is backpropagated through it next.
            loss_sample.backward(retain_graph=True)
            optimizer.acc_stats = False
            optimizer.zero_grad()  # clear the gradient used only for the true Fisher
        loss.backward(create_graph=create_graph)
        optimizer.step()


def train(train_loader, model, criterion, optimizer, epoch, args, local_rank, rank):
    """Run one training epoch over ``train_loader``.

    Each batch is moved to CUDA device ``local_rank``, run through ``model``,
    and used to update the parameters via the update rule chosen by
    ``args.opt``. Per-batch loss and top-1 accuracy feed running averages.
    Those are combined with ``loss_and_top1_acc`` at the end of the epoch
    (presumably a cross-process reduction; TODO confirm) and printed by the
    main process.

    Args:
        train_loader: iterable of ``(input, target)`` batches. ``len()`` is
            used for the remaining-time estimate.
        model: the network; switched to train mode here.
        criterion: loss function called as ``criterion(output, target)``.
        optimizer: optimizer matching ``args.opt``.
        epoch: current epoch index; used in the time estimate and the summary.
        args: namespace read for ``opt``, ``k_version`` (KO only),
            ``show_mem_per_gpu``, ``device`` and ``epochs``. After the first
            memory report, ``args.show_mem_per_gpu`` is set to False, so later
            batches and calls sharing this ``args`` skip it.
        local_rank: CUDA device index the batch tensors are moved to.
        rank: process rank. Only rank 0, or ``None``, prints.

    Returns:
        None. Results are printed on the main process.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_per_gpu = AverageMeter()
    top1_acc_per_gpu = AverageMeter()

    # AdaHessian needs a differentiable graph of the gradients.
    create_graph = args.opt == 'adahessian'
    is_main = _is_main_process(rank)

    model.train()

    end = time.time()
    for i, (inputs, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        target_var = target.cuda(local_rank)
        input_var = inputs.cuda(local_rank)

        # Peak memory of the forward pass is reported once, on the main process.
        report_mem = args.show_mem_per_gpu and is_main
        if report_mem:
            torch.cuda.reset_peak_memory_stats(args.device)

        output = model(input_var)
        loss = criterion(output, target_var)

        if report_mem:
            memory_used = torch.cuda.max_memory_allocated(args.device) / (1024 ** 3)
            print(f"Memory used per GPU for feedforward: {memory_used:.2f} GB")
            sys.stdout.flush()
            args.show_mem_per_gpu = False

        _backward_and_step(args, model, criterion, optimizer, input_var, target_var,
                           output, loss, create_graph)

        # Metrics are computed in fp32 (e.g. when the model ran in half precision).
        prec1 = accuracy(output.detach().float(), target_var)[0]
        batch_size = input_var.size(0)
        loss_per_gpu.update(loss.item(), batch_size)
        top1_acc_per_gpu.update(prec1.item(), batch_size)

        # measure elapsed time
        batch_time.update(time.time() - end)
        # The ETA is printed once per epoch, after batch_time has averaged a few
        # batches. Loaders with fewer than 11 batches never print it.
        if is_main and i == 10:
            print(f'Estimated remaining time for this training is : {seconds2days_hours_minutes_seconds(batch_time.avg*len(train_loader)*(args.epochs-epoch))}')
        end = time.time()

    loss_avg, top1_acc = loss_and_top1_acc(loss_per_gpu, top1_acc_per_gpu, local_rank)

    if is_main:
        print('Epoch: [{0}]'.format(epoch))
        print('train:     \t'
              'Loss {loss:.4f}\t'
              'Prec@1 {top1_acc:.3f}'.format(
                  loss=loss_avg, top1_acc=top1_acc))
